{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ba85ab13",
   "metadata": {},
   "source": [
    "### Misexpression metrics \n",
    "\n",
    "* Number and proportion of misexpression events across different z-score cutoffs \n",
    "* Number and proportion of genes with at least one misexpression event across different cutoffs \n",
    "* Number and proportion of samples that have a misexpression event across different z-score cutoffs \n",
    "* Number of misexpression events in each z-score bin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ce2f3e78",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8de8a319",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Working directory; every other path in this notebook is derived from it.\n",
    "wkdir = \"/lustre/scratch126/humgen/projects/interval_rna/interval_rna_seq/thomasVDS/misexpression_v3/\"\n",
    "wkdir_path = Path(wkdir)\n",
    "\n",
    "# Input directory, plus the output directory for the metrics written below\n",
    "# (created on first run, left untouched if it already exists).\n",
    "input_dir = wkdir_path / \"2_misexp_qc\"\n",
    "input_path = Path(input_dir)\n",
    "output_dir_path = input_path / \"misexp_metrics\"\n",
    "output_dir_path.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Thresholds used to call misexpression events: an event needs\n",
    "# TPM > mixexp_tpm_cutoff and a z-score above one of zscore_cutoffs.\n",
    "mixexp_tpm_cutoff = 0.5\n",
    "zscore_cutoffs = [2, 10, 20, 30, 40]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "23640920",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-derives the input/output paths from input_dir (same values as the configuration cell).\n",
    "# Safe to re-run: mkdir is a no-op when the directory already exists.\n",
    "input_path = Path(input_dir)\n",
    "output_dir_path = input_path / \"misexp_metrics\"\n",
    "output_dir_path.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d632f129",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the flattened expression table (one row per gene-sample pair).\n",
    "# Later cells read its gene_id, rna_id, TPM and z-score columns.\n",
    "tpm_zscore_flat_path = input_path / \"misexp_gene_cov_corr/tpm_zscore_4568smpls_8610genes_flat_misexp_corr_qc.csv\"\n",
    "tpm_zscore_flat_df = pd.read_csv(tpm_zscore_flat_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2833e7cd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of inactive genes passing QC: 8650\n"
     ]
    }
   ],
   "source": [
    "# IDs of all inactive genes passing QC: headerless file, gene ID in the first column.\n",
    "inactive_gene_id_pass_qc_path = wkdir_path / \"2_misexp_qc/misexp_gene_cov_corr/gene_id_post_tech_cov_qc_8650.txt\"\n",
    "inactive_gene_id_table = pd.read_csv(inactive_gene_id_pass_qc_path, sep=\"\\t\", header=None)\n",
    "inactive_gene_id_pass_qc = inactive_gene_id_table[0].tolist()\n",
    "\n",
    "# This count is the denominator for the gene-level percentages computed later.\n",
    "num_inactive_gene_id_pass_qc = len(inactive_gene_id_pass_qc)\n",
    "print(f\"Number of inactive genes passing QC: {num_inactive_gene_id_pass_qc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fd8d42aa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of genes in gene expression matrix: 8610\n",
      "Number of samples in gene expression matrix: 4568\n"
     ]
    }
   ],
   "source": [
    "# Distinct genes present in the expression table\n",
    "# (removed 40 with TPM = 0 across all samples).\n",
    "genes_pass_qc = tpm_zscore_flat_df.gene_id.unique()\n",
    "gene_number = len(genes_pass_qc)\n",
    "print(f\"Number of genes in gene expression matrix: {gene_number}\")\n",
    "\n",
    "# Distinct samples present in the expression table.\n",
    "smpl_pass_qc = tpm_zscore_flat_df.rna_id.unique()\n",
    "smpl_number = len(smpl_pass_qc)\n",
    "print(f\"Number of samples in gene expression matrix: {smpl_number}\")\n",
    "\n",
    "# Write the inactive genes that have z-scores to file, one gene ID per line.\n",
    "inactive_genes_path = output_dir_path / f\"inactive_genes_{gene_number}.txt\"\n",
    "with open(inactive_genes_path, 'w') as f_out:\n",
    "    f_out.writelines(f\"{gene}\\n\" for gene in genes_pass_qc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "023aaede",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of gene-sample pairs: 39513200\n",
      "Total number of gene-sample pairs with TPM > 0.1: 363951\n"
     ]
    }
   ],
   "source": [
    "# Total gene-sample pairs: every inactive gene passing QC (not only the genes\n",
    "# present in the expression table) times every sample.\n",
    "total_events = num_inactive_gene_id_pass_qc * smpl_number\n",
    "print(f\"Total number of gene-sample pairs: {total_events}\")\n",
    "\n",
    "# Gene-sample pairs with a TPM > 0.1 (counted by summing the boolean mask).\n",
    "tpm_above_point_one = tpm_zscore_flat_df.TPM > 0.1\n",
    "total_events_grtr_tpm1 = int(tpm_above_point_one.sum())\n",
    "print(f\"Total number of gene-sample pairs with TPM > 0.1: {total_events_grtr_tpm1}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "3a1e8dd8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Z-score cutoff: 2\n",
      "\tNumber of misexpression events: 28956\n",
      "\tNumber of misexpressed genes: 4437\n",
      "\tNumber of never misexpressed genes: 4213\n",
      "\tNumber of samples with a misexpression event: 4386/4568\n",
      "\tPercentage of samples with misexpression event: 96.01576182136601\n",
      "\tPercentage of genes-sample pairs misexpressed: 0.0732818399927113%\n",
      "\tPercentage of genes with at least one misexpression event: 51.29479768786127%\n",
      "\tMedian number of misexpression events: 4.0\n",
      "Z-score cutoff: 10\n",
      "\tNumber of misexpression events: 17461\n",
      "\tNumber of misexpressed genes: 4437\n",
      "\tNumber of never misexpressed genes: 4213\n",
      "\tNumber of samples with a misexpression event: 3511/4568\n",
      "\tPercentage of samples with misexpression event: 76.86077057793345\n",
      "\tPercentage of genes-sample pairs misexpressed: 0.04419029590111659%\n",
      "\tPercentage of genes with at least one misexpression event: 51.29479768786127%\n",
      "\tMedian number of misexpression events: 2.0\n",
      "Z-score cutoff: 20\n",
      "\tNumber of misexpression events: 7495\n",
      "\tNumber of misexpressed genes: 3891\n",
      "\tNumber of never misexpressed genes: 4759\n",
      "\tNumber of samples with a misexpression event: 2014/4568\n",
      "\tPercentage of samples with misexpression event: 44.08931698774081\n",
      "\tPercentage of genes-sample pairs misexpressed: 0.018968344755676585%\n",
      "\tPercentage of genes with at least one misexpression event: 44.982658959537574%\n",
      "\tMedian number of misexpression events: 0.0\n",
      "Z-score cutoff: 30\n",
      "\tNumber of misexpression events: 3154\n",
      "\tNumber of misexpressed genes: 2610\n",
      "\tNumber of never misexpressed genes: 6040\n",
      "\tNumber of samples with a misexpression event: 1113/4568\n",
      "\tPercentage of samples with misexpression event: 24.365148861646237\n",
      "\tPercentage of genes-sample pairs misexpressed: 0.00798214267637144%\n",
      "\tPercentage of genes with at least one misexpression event: 30.173410404624278%\n",
      "\tMedian number of misexpression events: 0.0\n",
      "Z-score cutoff: 40\n",
      "\tNumber of misexpression events: 1218\n",
      "\tNumber of misexpressed genes: 1205\n",
      "\tNumber of never misexpressed genes: 7445\n",
      "\tNumber of samples with a misexpression event: 599/4568\n",
      "\tPercentage of samples with misexpression event: 13.112959719789844\n",
      "\tPercentage of genes-sample pairs misexpressed: 0.0030825141977870686%\n",
      "\tPercentage of genes with at least one misexpression event: 13.930635838150291%\n",
      "\tMedian number of misexpression events: 0.0\n"
     ]
    }
   ],
   "source": [
    "### calculate misexpression metrics \n",
    "# For each z-score cutoff: count events, misexpressed / never-misexpressed genes and\n",
    "# affected samples, write the two gene lists to file, and collect one summary row.\n",
    "\n",
    "misexp_metrics_dict = {}\n",
    "count_misexp_per_smpl_df_list = []\n",
    "\n",
    "for row, zscore in enumerate(zscore_cutoffs): \n",
    "    print(f\"Z-score cutoff: {zscore}\")\n",
    "    # misexpression events: TPM above the expression cutoff AND z-score above this cutoff \n",
    "    misexp_events_df = tpm_zscore_flat_df[(tpm_zscore_flat_df.TPM > mixexp_tpm_cutoff) &\n",
    "                                          (tpm_zscore_flat_df[\"z-score\"] > zscore)]\n",
    "\n",
    "    total_misexp_events = misexp_events_df.shape[0]\n",
    "    print(f\"\\tNumber of misexpression events: {total_misexp_events}\")\n",
    "\n",
    "    # misexpressed genes and non-misexpressed genes at z-score cutoff \n",
    "    misexp_genes = misexp_events_df.gene_id.unique()\n",
    "    # set for constant-time membership tests; `in` on the array scans it on every lookup\n",
    "    misexp_gene_set = set(misexp_genes)\n",
    "    num_misexp_genes = len(misexp_genes)\n",
    "    print(f\"\\tNumber of misexpressed genes: {num_misexp_genes}\")\n",
    "    misexp_gene_path = output_dir_path.joinpath(f\"misexp_genes_tpm{mixexp_tpm_cutoff}_z{zscore}.txt\")\n",
    "    with open(misexp_gene_path, 'w') as misexp_genes_out:\n",
    "        for gene_id in misexp_genes: \n",
    "            misexp_genes_out.write(f\"{gene_id}\\n\")\n",
    "\n",
    "    # note using all 8650 inactive genes not just genes in matrix \n",
    "    never_misexp_genes = [gene_id for gene_id in inactive_gene_id_pass_qc if gene_id not in misexp_gene_set]\n",
    "    num_never_misexp_genes = len(never_misexp_genes)\n",
    "    print(f\"\\tNumber of never misexpressed genes: {num_never_misexp_genes}\")\n",
    "    never_misexp_gene_path = output_dir_path.joinpath(f\"never_misexp_genes_tpm{mixexp_tpm_cutoff}_z{zscore}.txt\")\n",
    "    with open(never_misexp_gene_path, 'w') as never_misexp_genes_out:\n",
    "        for gene_id in never_misexp_genes: \n",
    "            never_misexp_genes_out.write(f\"{gene_id}\\n\")\n",
    "\n",
    "    # percentage of samples with a misexpression event \n",
    "    smpls_with_misexp = misexp_events_df.rna_id.unique()\n",
    "    smpls_with_misexp_set = set(smpls_with_misexp)\n",
    "    num_smpls_with_misexp = len(smpls_with_misexp)\n",
    "    perc_smpls_with_misexp = (num_smpls_with_misexp/smpl_number) * 100\n",
    "    print(f\"\\tNumber of samples with a misexpression event: {num_smpls_with_misexp}/{smpl_number}\")\n",
    "    print(f\"\\tPercentage of samples with misexpression event: {perc_smpls_with_misexp}\")\n",
    "    # percentage of gene-sample pairs misexpressed \n",
    "    pc_gene_smpl_misexp = (total_misexp_events/total_events) * 100\n",
    "    print(f\"\\tPercentage of genes-sample pairs misexpressed: {pc_gene_smpl_misexp}%\")\n",
    "    # percentage of genes with at least one misexpression event \n",
    "    pc_genes_misexp = (num_misexp_genes/num_inactive_gene_id_pass_qc) * 100\n",
    "    print(f\"\\tPercentage of genes with at least one misexpression event: {pc_genes_misexp}%\")\n",
    "\n",
    "    # median misexpression events per sample \n",
    "    count_misexp_per_smpl_df = pd.DataFrame(misexp_events_df.groupby(by=\"rna_id\")[\"gene_id\"].count())\n",
    "    count_misexp_per_smpl_df = count_misexp_per_smpl_df.rename(columns={\"gene_id\":\"misexp_count\"})\n",
    "    # samples without any event are absent from the groupby result; add them with a count\n",
    "    # of 0 so that they contribute to the median \n",
    "    rna_id_zero_count_dict = {rna_id:0 for rna_id in smpl_pass_qc if rna_id not in smpls_with_misexp_set}\n",
    "    rna_id_zero_count_df = pd.DataFrame.from_dict(rna_id_zero_count_dict, orient=\"index\")\n",
    "    rna_id_zero_count_df = rna_id_zero_count_df.rename(columns={0:\"misexp_count\"})\n",
    "    count_misexp_per_smpl_df = pd.concat([count_misexp_per_smpl_df, rna_id_zero_count_df])\n",
    "    # sanity check: per-sample counts must add up to the total number of events \n",
    "    if count_misexp_per_smpl_df.misexp_count.sum() != total_misexp_events:\n",
    "        raise ValueError(\"Sum of misexpression events per sample is not equal to total misexpression events.\")\n",
    "    median_misexp = count_misexp_per_smpl_df.misexp_count.median()\n",
    "    count_misexp_per_smpl_df[\"zscore\"] = f\"> {zscore}\"\n",
    "    count_misexp_per_smpl_df_list.append(count_misexp_per_smpl_df)\n",
    "    print(f\"\\tMedian number of misexpression events: {median_misexp}\")\n",
    "\n",
    "    # one summary row per cutoff; the order must match the column names used to build the metrics table \n",
    "    misexp_metrics_dict[row] = [zscore, num_smpls_with_misexp, perc_smpls_with_misexp, total_misexp_events,\n",
    "                                pc_gene_smpl_misexp, num_misexp_genes, pc_genes_misexp, median_misexp]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "db8fcc9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column names, in the same order as the per-cutoff lists stored in misexp_metrics_dict.\n",
    "misexp_metrics_cols = [\n",
    "    \"zscore\",                 # z-score cutoff\n",
    "    \"smpl_misexp\",            # samples with at least one misexpression event\n",
    "    \"smpl_misexp_perc\",       # ... as a percentage of all samples\n",
    "    \"gene_smpl_misexp\",       # misexpression events (gene-sample pairs)\n",
    "    \"gene_smpl_misexp_perc\",  # ... as a percentage of all gene-sample pairs\n",
    "    \"gene_misexp\",            # genes with at least one misexpression event\n",
    "    \"gene_misexp_perc\",       # ... as a percentage of inactive genes passing QC\n",
    "    \"median_misexp\",          # median misexpression events per sample\n",
    "]\n",
    "# One row per z-score cutoff.\n",
    "misexp_metrics_df = pd.DataFrame.from_dict(\n",
    "    misexp_metrics_dict,\n",
    "    orient=\"index\",\n",
    "    columns=misexp_metrics_cols,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "2dd4b49a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the per-cutoff metrics table as CSV, without the row index.\n",
    "misexp_metrics_path = output_dir_path / \"misexp_metrics.csv\"\n",
    "misexp_metrics_df.to_csv(misexp_metrics_path, index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41d0f14b",
   "metadata": {},
   "source": [
    "### Number of misexpression events per z-score bin "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "bf2bd681",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Events above the TPM cutoff with a z-score > 2; copied so that a bin column\n",
    "# can be added without touching the full expression table.\n",
    "passes_tpm_cutoff = tpm_zscore_flat_df.TPM > mixexp_tpm_cutoff\n",
    "passes_zscore_cutoff = tpm_zscore_flat_df[\"z-score\"] > 2\n",
    "misexp_tpm_zscore_df = tpm_zscore_flat_df[passes_tpm_cutoff & passes_zscore_cutoff].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "e31d8d08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assign every event to a unit-width z-score bin (2-3, 3-4, ...); pd.cut intervals are right-inclusive.\n",
    "zscore_bin_lower = 2   # events were selected with z-score > 2\n",
    "zscore_bin_upper = 68  # previous fixed upper edge, kept as a floor so existing bins are unchanged\n",
    "\n",
    "# Extend the upper edge when the data exceed it: values outside the bin range get a NaN\n",
    "# bin and would silently disappear from the per-bin counts.\n",
    "max_zscore = misexp_tpm_zscore_df[\"z-score\"].max()\n",
    "if max_zscore > zscore_bin_upper:  # False for NaN (empty frame), so the floor is kept\n",
    "    # round up to the next integer so the largest z-score falls inside the last bin\n",
    "    zscore_bin_upper = int(max_zscore) + int(max_zscore > int(max_zscore))\n",
    "\n",
    "bins = list(range(zscore_bin_lower, zscore_bin_upper + 1))\n",
    "bin_names = [f\"{edge}-{edge + 1}\" for edge in bins[:-1]]\n",
    "misexp_tpm_zscore_df[\"z-score_bins\"] = pd.cut(misexp_tpm_zscore_df[\"z-score\"], bins=bins, labels=bin_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "e5fafab6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of misexpression events in each z-score bin.\n",
    "# observed=False keeps bins without events as zero-count rows; it is passed explicitly\n",
    "# because the default for categorical groupers changes in newer pandas releases.\n",
    "count_misexp_by_bin_df = (\n",
    "    misexp_tpm_zscore_df\n",
    "    .groupby([\"z-score_bins\"], as_index=False, observed=False)\n",
    "    .gene_id.count()\n",
    ")\n",
    "count_misexp_by_bin_df = count_misexp_by_bin_df.rename(columns={\"gene_id\": \"misexp_count\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "ce091ae1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the per-bin event counts as a tab-separated file, without the row index.\n",
    "misexp_events_per_zscore_bin_path = output_dir_path / \"misexp_events_per_zscore_bin.tsv\"\n",
    "count_misexp_by_bin_df.to_csv(misexp_events_per_zscore_bin_path, sep=\"\\t\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ac2abf6",
   "metadata": {},
   "source": [
    "### Median misexpression events per sample (across z-score cutoffs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "09539139",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-sample event counts from every z-score cutoff into one long table\n",
    "# and move the sample IDs from the index into an rna_id column.\n",
    "count_misexp_per_smpl_all_cutoffs_df = (\n",
    "    pd.concat(count_misexp_per_smpl_df_list)\n",
    "    .rename_axis(\"rna_id\")\n",
    "    .reset_index()\n",
    ")\n",
    "# write to file \n",
    "# NOTE(review): unlike the other outputs this is written without index=False, so the\n",
    "# file starts with an unnamed integer index column - confirm downstream readers expect it.\n",
    "count_misexp_per_smpl_all_cutoffs_path = output_dir_path / \"misexp_events_per_smpl_zscore.tsv\"\n",
    "count_misexp_per_smpl_all_cutoffs_df.to_csv(count_misexp_per_smpl_all_cutoffs_path, sep=\"\\t\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49ce990c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
